{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L5tGkC_j050w"
   },
   "source": [
    "# Tokenisation *Byte-Pair Encoding*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fjmzCwMq0504"
   },
   "source": [
    "Installez les bibliothèques 🤗 *Transformers* et 🤗 *Datasets* pour exécuter ce *notebook*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "D75bQauc0505"
   },
   "outputs": [],
   "source": [
    "!pip install datasets transformers[sentencepiece]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "s4G8-gDG0506"
   },
   "outputs": [],
   "source": [
    "corpus = [\n",
    "    \"C'est le cours d'Hugging Face.\",\n",
    "    \"Ce chapitre traite de la tokenisation.\",\n",
    "    \"Cette section présente plusieurs algorithmes de tokenizer.\",\n",
    "    \"Avec un peu de chance, vous serez en mesure de comprendre comment ils sont entraînés et génèrent des tokens.\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CRb5o8vS0507"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"asi/gpt-fr-cased-small\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K_szN_Y-0508"
   },
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "word_freqs = defaultdict(int)\n",
    "\n",
    "for text in corpus:\n",
    "    words_with_offsets = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)\n",
    "    new_words = [word for word, offset in words_with_offsets]\n",
    "    for word in new_words:\n",
    "        word_freqs[word] += 1\n",
    "\n",
    "print(word_freqs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ko-UJpnt0508"
   },
   "outputs": [],
   "source": [
    "alphabet = []\n",
    "\n",
    "for word in word_freqs.keys():\n",
    "    for letter in word:\n",
    "        if letter not in alphabet:\n",
    "            alphabet.append(letter)\n",
    "alphabet.sort()\n",
    "\n",
    "print(alphabet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IR_iQ6Vq0509"
   },
   "outputs": [],
   "source": [
    "vocab = [\"<|endoftext|>\"] + alphabet.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "u5Avma5F050-"
   },
   "outputs": [],
   "source": [
    "splits = {word: [c for c in word] for word in word_freqs.keys()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PVHu72DI050_"
   },
   "outputs": [],
   "source": [
    "def compute_pair_freqs(splits):\n",
    "    pair_freqs = defaultdict(int)\n",
    "    for word, freq in word_freqs.items():\n",
    "        split = splits[word]\n",
    "        if len(split) == 1:\n",
    "            continue\n",
    "        for i in range(len(split) - 1):\n",
    "            pair = (split[i], split[i + 1])\n",
    "            pair_freqs[pair] += freq\n",
    "    return pair_freqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "48S6KmcB050_"
   },
   "outputs": [],
   "source": [
    "pair_freqs = compute_pair_freqs(splits)\n",
    "\n",
    "for i, key in enumerate(pair_freqs.keys()):\n",
    "    print(f\"{key}: {pair_freqs[key]}\")\n",
    "    if i >= 5:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kVgFt81n051A"
   },
   "outputs": [],
   "source": [
    "best_pair = \"\"\n",
    "max_freq = None\n",
    "\n",
    "for pair, freq in pair_freqs.items():\n",
    "    if max_freq is None or max_freq < freq:\n",
    "        best_pair = pair\n",
    "        max_freq = freq\n",
    "\n",
    "print(best_pair, max_freq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NVuJAlmU051B"
   },
   "outputs": [],
   "source": [
    "merges = {(\"Ġ\", \"t\"): \"Ġt\"}\n",
    "vocab.append(\"Ġt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NWhbhqJU051C"
   },
   "outputs": [],
   "source": [
    "def merge_pair(a, b, splits):\n",
    "    for word in word_freqs:\n",
    "        split = splits[word]\n",
    "        if len(split) == 1:\n",
    "            continue\n",
    "\n",
    "        i = 0\n",
    "        while i < len(split) - 1:\n",
    "            if split[i] == a and split[i + 1] == b:\n",
    "                split = split[:i] + [a + b] + split[i + 2 :]\n",
    "            else:\n",
    "                i += 1\n",
    "        splits[word] = split\n",
    "    return splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rJN4TLmm051C"
   },
   "outputs": [],
   "source": [
    "splits = merge_pair(\"Ġ\", \"t\", splits)\n",
    "splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ElXK18p5051D"
   },
   "outputs": [],
   "source": [
    "vocab_size = 50\n",
    "\n",
    "while len(vocab) < vocab_size:\n",
    "    pair_freqs = compute_pair_freqs(splits)\n",
    "    best_pair = \"\"\n",
    "    max_freq = None\n",
    "    for pair, freq in pair_freqs.items():\n",
    "        if max_freq is None or max_freq < freq:\n",
    "            best_pair = pair\n",
    "            max_freq = freq\n",
    "    splits = merge_pair(*best_pair, splits)\n",
    "    merges[best_pair] = best_pair[0] + best_pair[1]\n",
    "    vocab.append(best_pair[0] + best_pair[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ehoZwJgB051D"
   },
   "outputs": [],
   "source": [
    "print(merges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Fn5zl1AC051E"
   },
   "outputs": [],
   "source": [
    "print(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tCCJYkXU051E"
   },
   "outputs": [],
   "source": [
    "def tokenize(text):\n",
    "    pre_tokenize_result = tokenizer._tokenizer.pre_tokenizer.pre_tokenize_str(text)\n",
    "    pre_tokenized_text = [word for word, offset in pre_tokenize_result]\n",
    "    splits = [[l for l in word] for word in pre_tokenized_text]\n",
    "    for pair, merge in merges.items():\n",
    "        for idx, split in enumerate(splits):\n",
    "            i = 0\n",
    "            while i < len(split) - 1:\n",
    "                if split[i] == pair[0] and split[i + 1] == pair[1]:\n",
    "                    split = split[:i] + [merge] + split[i + 2 :]\n",
    "                else:\n",
    "                    i += 1\n",
    "            splits[idx] = split\n",
    "\n",
    "    return sum(splits, [])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ri4HzbPD051E"
   },
   "outputs": [],
   "source": [
    "tokenize(\"Ce n'est pas un token.\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
